import pathlib
import argparse
import os
import yaml
from typing import List, Dict, Union, Tuple, Sequence
from tools.codegen.gen import FileManager, get_grouped_native_functions, LineLoader, parse_native_yaml
from tools.codegen.model import (ExternalBackendFunction, ExternalBackendFunctionsGroup,
                                 NativeFunction, NativeFunctionsGroup, OperatorName,
                                 ExternalBackendMetadata, assert_never)
from tools.codegen.selective_build.selector import SelectiveBuilder
from tools.codegen.utils import Target, concatMap
import tools.codegen.dest as dest

def parse_backend_yaml(
        backend_yaml_path: str,
        grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]]
) -> Tuple[str, List[Union[ExternalBackendFunction, ExternalBackendFunctionsGroup]]]:
    """Parse an external backend's yaml file and attach it to the native functions.

    The yaml file must be a mapping with the required keys ``cpp_namespace``
    and ``backend`` and the optional list-valued keys ``supported`` and
    ``autograd``; any other key is rejected.

    Args:
        backend_yaml_path: path of the backend yaml file to read.
        grouped_native_functions: native functions (or groups of them) that
            the operators named in the yaml file are matched against.

    Returns:
        A ``(cpp_namespace, functions)`` tuple, where ``functions`` holds one
        external-backend wrapper per entry of ``grouped_native_functions``,
        in the same order.

    Raises:
        KeyError: if ``cpp_namespace`` or ``backend`` is missing.
        AssertionError: if the file is not a mapping, if ``supported`` or
            ``autograd`` is not a list, if unexpected keys are present, or if
            a listed operator does not name a known native function.
    """
    with open(backend_yaml_path, 'r') as f:
        yaml_values = yaml.load(f, Loader=LineLoader)
    assert isinstance(yaml_values, dict)

    # NOTE(review): LineLoader appears to stamp every parsed mapping with a
    # bookkeeping '__line__' entry -- TODO confirm against its definition.
    # Discard it (harmlessly, if absent) so that it is not reported as an
    # unexpected key below.
    yaml_values.pop('__line__', None)

    cpp_namespace = yaml_values.pop('cpp_namespace')
    backend = yaml_values.pop('backend')

    supported = yaml_values.pop('supported', [])
    assert isinstance(supported, list), f'expected "supported" to be a list, but got: {supported}'
    supported_autograd = yaml_values.pop('autograd', [])
    # Validate the autograd list itself (not "supported" a second time).
    assert isinstance(supported_autograd, list), \
        f'expected "autograd" to be a list, but got: {supported_autograd}'

    # Every recognised key has been popped by now, so anything left over is a
    # key this parser does not understand.
    unexpected_keys = ", ".join(str(key) for key in yaml_values)
    assert len(yaml_values) == 0, \
        f'{backend_yaml_path} contains unexpected keys: {unexpected_keys}'

    metadata: Dict[OperatorName, ExternalBackendMetadata] = {}
    # "supported" entries are processed first, then "autograd" entries, so the
    # insertion order into `metadata` matches the order of the yaml sections.
    for ops, is_autograd in ((supported, False), (supported_autograd, True)):
        for op in ops:
            op_name = OperatorName.parse(op)
            m = ExternalBackendMetadata(op_name, backend, is_autograd=is_autograd)
            metadata[m.operator] = m

    # Flatten the groups so that every individual native function can be
    # looked up by operator name.
    native_functions_map: Dict[OperatorName, NativeFunction] = {
        f.func.name: f
        for f in concatMap(lambda f: [f] if isinstance(f, NativeFunction) else list(f.functions()), grouped_native_functions)
    }

    def native_to_external(
            g: Union[NativeFunction, NativeFunctionsGroup]
    ) -> Union[ExternalBackendFunction, ExternalBackendFunctionsGroup]:
        """Wrap one native function (or group) together with its backend metadata."""
        if isinstance(g, NativeFunction):
            f = g
            # `m` is None when the backend yaml does not mention this operator.
            m = metadata.get(f.func.name, None)
            return ExternalBackendFunction(f, m)
        elif isinstance(g, NativeFunctionsGroup):
            return ExternalBackendFunctionsGroup.from_function_group(g, metadata)
        else:
            assert_never(g)

    # Reject operators from the yaml file that do not name a native function.
    for op_name in metadata:
        if op_name not in native_functions_map:
            raise AssertionError(f"Found an invalid operator name: {op_name}")
    return cpp_namespace, [native_to_external(g) for g in grouped_native_functions]

def main() -> None:
    """Command-line entry point: generate the backend stub files.

    Parses the command line, loads the native function definitions and the
    backend yaml file given by ``--source_yaml``, and writes
    ``aten_xla_type.h``, ``aten_xla_type_default.h`` and
    ``aten_xla_type_default.cpp`` through a ``FileManager`` rooted at
    ``--output_dir``.
    """
    def parse_bool(value: str) -> bool:
        """Convert a command-line string into a bool.

        ``type=bool`` would treat every non-empty string (including "False")
        as True, so the accepted spellings are listed explicitly instead.
        """
        normalized = value.strip().lower()
        if normalized in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        if normalized in ('false', 'f', 'no', 'n', '0', 'off', ''):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, but got: {value!r}')

    parser = argparse.ArgumentParser(description='Generate backend stub files')
    parser.add_argument(
        '-s',
        '--source_yaml',
        # Without this path there is nothing to parse, so fail up front with a
        # usage message rather than later when the file is opened.
        required=True,
        help='path to source yaml file containing operator external definitions')
    parser.add_argument(
        '-o', '--output_dir', help='output directory')
    parser.add_argument(
        # nargs='?' keeps "--dry_run True" / "--dry_run False" working while
        # also allowing a bare "--dry_run" (which yields `const`).
        '--dry_run', type=parse_bool, nargs='?', const=True, default=False,
        help='run in dry-run mode; takes an optional true/false value')
    options = parser.parse_args()

    # Assumes that this file lives at PYTORCH_ROOT/tools/codegen/gen_backend_stubs.py
    pytorch_root = pathlib.Path(__file__).parent.parent.parent.absolute()
    template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates")

    def make_file_manager(install_dir: str) -> FileManager:
        """Build a FileManager that writes into `install_dir` using the ATen templates."""
        return FileManager(install_dir=install_dir, template_dir=template_dir, dry_run=options.dry_run)

    fm = make_file_manager(options.output_dir)

    native_yaml_path = os.path.join(pytorch_root, 'aten/src/ATen/native/native_functions.yaml')
    grouped_native_functions = get_grouped_native_functions(native_yaml_path)
    cpp_namespace, external_backend_functions = parse_backend_yaml(options.source_yaml, grouped_native_functions)

    generated_comment = 'Autogenerated file by gen_backend_stubs.py. Do not edit directly!'

    # Each environment is built inside a lambda, so it is only computed when
    # the FileManager asks for it.
    fm.write('aten_xla_type.h', lambda: {
        'generated_comment': generated_comment,
        'cpp_namespace': cpp_namespace,
        'dispatch_xla_declarations': list(concatMap(dest.compute_native_function_declaration, external_backend_functions)),
    })

    fm.write('aten_xla_type_default.h', lambda: {
        'generated_comment': generated_comment,
        'cpp_namespace': cpp_namespace,
        'dispatch_aten_fallback_declarations': list(concatMap(
            dest.GenExternalAtenFallback(Target.NAMESPACED_DECLARATION), external_backend_functions
        )),
    })

    fm.write('aten_xla_type_default.cpp', lambda: {
        'generated_comment': generated_comment,
        'cpp_namespace': cpp_namespace,
        # TODO: after cpu fallbacks are moved to a boxed kernel,
        # merge registrations / definitions into RegisterDispatchKey
        'dispatch_aten_fallback_definitions': list(concatMap(
            dest.GenExternalAtenFallback(Target.NAMESPACED_DEFINITION), external_backend_functions
        )),
        # Registrations are split by kernel kind: non-autograd kernels here,
        # autograd kernels in the next entry.
        'dispatch_registrations': list(concatMap(
            dest.GenExternalAtenFallback(Target.REGISTRATION), [e for e in external_backend_functions if not e.is_autograd_kernel]
        )),
        'dispatch_autograd_registrations': list(concatMap(
            dest.GenExternalAtenFallback(Target.REGISTRATION), [e for e in external_backend_functions if e.is_autograd_kernel]
        )),
    })

# Run the generator only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
